Explaining Text Classification
# Explainer toolkit entry points: feature-attribution (SHAP-style) and metrics explainers.
from explainer.explainers import feature_attributions_explainer, metrics_explainer
import warnings
# Silence library warnings so the notebook output stays readable.
warnings.filterwarnings('ignore')
import os
# Suppress Intel OpenMP (KMP) runtime warnings emitted by TF/MKL backends.
os.environ['KMP_WARNINGS'] = 'off'
import numpy as np
from sklearn import datasets
# Full list of the 20 Newsgroups topics (kept for reference); we train on five.
all_categories = [
    'alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc',
    'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x',
    'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball',
    'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med',
    'sci.space', 'soc.religion.christian', 'talk.politics.guns',
    'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc',
]
selected_categories = ['alt.atheism', 'comp.graphics', 'rec.motorcycles',
                       'sci.space', 'talk.politics.misc']

# Download the train/test splits restricted to the five chosen topics.
X_train_text, Y_train = datasets.fetch_20newsgroups(
    subset="train", categories=selected_categories, return_X_y=True)
X_test_text, Y_test = datasets.fetch_20newsgroups(
    subset="test", categories=selected_categories, return_X_y=True)

# Use numpy arrays so slicing/fancy indexing behaves consistently downstream.
X_train_text, X_test_text = np.array(X_train_text), np.array(X_test_text)

classes = np.unique(Y_train)  # integer labels 0..4
mapping = {label: name for label, name in zip(classes, selected_categories)}
len(X_train_text), len(X_test_text), classes, mapping
(2720,
1810,
array([0, 1, 2, 3, 4]),
{0: 'alt.atheism',
1: 'comp.graphics',
2: 'rec.motorcycles',
3: 'sci.space',
4: 'talk.politics.misc'})
print(Y_test)
[2 3 0 ... 3 2 3]
Vectorize Text Data
import sklearn
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

# TF-IDF bag-of-words over the 50k most frequent terms; fit on ALL text
# (train + test) so both splits share a single vocabulary.
vectorizer = TfidfVectorizer(max_features=50000)
vectorizer.fit(np.concatenate((X_train_text, X_test_text)))

# Vectorize each split and densify (the downstream Keras model wants arrays).
X_train, X_test = (vectorizer.transform(split).toarray()
                   for split in (X_train_text, X_test_text))
X_train.shape, X_test.shape
((2720, 50000), (1810, 50000))
Define the Model
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers

def create_model(input_shape=None, n_classes=None):
    """Build a small dense softmax text classifier.

    Parameters
    ----------
    input_shape : tuple, optional
        Shape of one input sample. Defaults to the TF-IDF feature shape
        taken from the global ``X_train`` (backward compatible with the
        original zero-argument call).
    n_classes : int, optional
        Number of output units. Defaults to ``len(classes)``.

    Returns
    -------
    tensorflow.keras.Sequential
        Uncompiled model: 128 -> 64 ReLU layers, softmax output.
    """
    if input_shape is None:
        input_shape = X_train.shape[1:]
    if n_classes is None:
        n_classes = len(classes)
    return Sequential([
        layers.Input(shape=input_shape),
        layers.Dense(128, activation="relu"),
        layers.Dense(64, activation="relu"),
        # Softmax over the classes -> per-class probabilities.
        layers.Dense(n_classes, activation="softmax"),
    ])

model = create_model()
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 128) 6400128
dense_1 (Dense) (None, 64) 8256
dense_2 (Dense) (None, 5) 325
=================================================================
Total params: 6,408,709
Trainable params: 6,408,709
Non-trainable params: 0
_________________________________________________________________
Compile and Train Model
# Integer class-id targets, hence the sparse categorical cross-entropy loss.
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
history = model.fit(X_train, Y_train,
                    batch_size=256, epochs=5,
                    validation_data=(X_test, Y_test))
Evaluate Model Performance
from sklearn.metrics import accuracy_score, classification_report

# Predicted class-probability matrices for both splits.
train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

# Hard labels: argmax over the class axis, computed once and reused
# (the original recomputed np.argmax(test_preds, ...) for each metric).
train_labels = np.argmax(train_preds, axis=1)
test_labels = np.argmax(test_preds, axis=1)

print("Train Accuracy : {:.3f}".format(accuracy_score(Y_train, train_labels)))
print("Test Accuracy : {:.3f}".format(accuracy_score(Y_test, test_labels)))
print("\nClassification Report : ")
print(classification_report(Y_test, test_labels, target_names=selected_categories))
Show code cell output
1/85 [..............................] - ETA: 10s
7/85 [=>............................] - ETA: 0s
19/85 [=====>........................] - ETA: 0s
32/85 [==========>...................] - ETA: 0s
41/85 [=============>................] - ETA: 0s
51/85 [=================>............] - ETA: 0s
63/85 [=====================>........] - ETA: 0s
76/85 [=========================>....] - ETA: 0s
85/85 [==============================] - 1s 5ms/step
1/57 [..............................] - ETA: 5s
11/57 [====>.........................] - ETA: 0s
21/57 [==========>...................] - ETA: 0s
32/57 [===============>..............] - ETA: 0s
44/57 [======================>.......] - ETA: 0s
56/57 [============================>.] - ETA: 0s
57/57 [==============================] - 0s 5ms/step
Train Accuracy : 1.000
Test Accuracy : 0.949
Classification Report :
precision recall f1-score support
alt.atheism 0.98 0.91 0.94 319
comp.graphics 0.94 0.96 0.95 389
rec.motorcycles 0.98 0.98 0.98 398
sci.space 0.93 0.94 0.93 394
talk.politics.misc 0.92 0.95 0.93 310
accuracy 0.95 1810
macro avg 0.95 0.95 0.95 1810
weighted avg 0.95 0.95 0.95 1810
# One-hot-encode classes: the metrics explainer expects one-hot ground truth.
oh_Y_test = np.eye(len(classes))[Y_test]
# Confusion matrix + per-class report built from the predicted probabilities.
cm = metrics_explainer['confusionmatrix'](oh_Y_test, test_preds, selected_categories)
cm.visualize()
print(cm.report)
precision recall f1-score support
alt.atheism 0.98 0.91 0.94 319
comp.graphics 0.94 0.96 0.95 389
rec.motorcycles 0.98 0.98 0.98 398
sci.space 0.93 0.94 0.93 394
talk.politics.misc 0.92 0.95 0.93 310
accuracy 0.95 1810
macro avg 0.95 0.95 0.95 1810
weighted avg 0.95 0.95 0.95 1810
# Per-class precision-recall and ROC curves from the same one-hot targets.
plotter = metrics_explainer['plot'](oh_Y_test, test_preds, selected_categories)
plotter.pr_curve()
plotter.roc_curve()
import re

# Two held-out samples to explain.
X_batch_text = X_test_text[1:3]
X_batch = X_test[1:3]

print("Samples : ")
for text in X_batch_text:
    # Tokenize on non-word characters, mirroring the explainer's masker regex.
    print(re.split(r"\W+", text))
    print()

preds_proba = model.predict(X_batch)
preds = preds_proba.argmax(axis=1)

actual_names = [selected_categories[target] for target in Y_test[1:3]]
predicted_names = [selected_categories[target] for target in preds]
print(f"Actual Target Values : {actual_names}")
print(f"Predicted Target Values : {predicted_names}")
print(f"Predicted Probabilities : {preds_proba.max(axis=1)}")
Samples :
['From', 'prb', 'access', 'digex', 'net', 'Pat', 'Subject', 'Re', 'Near', 'Miss', 'Asteroids', 'Q', 'Organization', 'Express', 'Access', 'Online', 'Communications', 'Greenbelt', 'MD', 'USA', 'Lines', '4', 'Distribution', 'sci', 'NNTP', 'Posting', 'Host', 'access', 'digex', 'net', 'TRry', 'the', 'SKywatch', 'project', 'in', 'Arizona', 'pat', '']
['From', 'cobb', 'alexia', 'lis', 'uiuc', 'edu', 'Mike', 'Cobb', 'Subject', 'Science', 'and', 'theories', 'Organization', 'University', 'of', 'Illinois', 'at', 'Urbana', 'Lines', '19', 'As', 'per', 'various', 'threads', 'on', 'science', 'and', 'creationism', 'I', 've', 'started', 'dabbling', 'into', 'a', 'book', 'called', 'Christianity', 'and', 'the', 'Nature', 'of', 'Science', 'by', 'JP', 'Moreland', 'A', 'question', 'that', 'I', 'had', 'come', 'from', 'one', 'of', 'his', 'comments', 'He', 'stated', 'that', 'God', 'is', 'not', 'necessarily', 'a', 'religious', 'term', 'but', 'could', 'be', 'used', 'as', 'other', 'scientific', 'terms', 'that', 'give', 'explanation', 'for', 'events', 'or', 'theories', 'without', 'being', 'a', 'proven', 'scientific', 'fact', 'I', 'think', 'I', 'got', 'his', 'point', 'I', 'can', 'quote', 'the', 'section', 'if', 'I', 'm', 'being', 'vague', 'The', 'examples', 'he', 'gave', 'were', 'quarks', 'and', 'continental', 'plates', 'Are', 'there', 'explanations', 'of', 'science', 'or', 'parts', 'of', 'theories', 'that', 'are', 'not', 'measurable', 'in', 'and', 'of', 'themselves', 'or', 'can', 'everything', 'be', 'quantified', 'measured', 'tested', 'etc', 'MAC', 'Michael', 'A', 'Cobb', 'and', 'I', 'won', 't', 'raise', 'taxes', 'on', 'the', 'middle', 'University', 'of', 'Illinois', 'class', 'to', 'pay', 'for', 'my', 'programs', 'Champaign', 'Urbana', 'Bill', 'Clinton', '3rd', 'Debate', 'cobb', 'alexia', 'lis', 'uiuc', 'edu', 'Nobody', 'can', 'explain', 'everything', 'to', 'anybody', 'G', 'K', 'Chesterton', '']
1/1 [==============================] - ETA: 0s
1/1 [==============================] - 0s 38ms/step
Actual Target Values : ['sci.space', 'alt.atheism']
Predicted Target Values : ['sci.space', 'alt.atheism']
Predicted Probabilities : [0.9238798 0.75361186]
SHAP Partition Explainer
Visualize SHAP Values Correct Predictions
def make_predictions(X_batch_text):
    """Vectorize raw texts with the fitted TF-IDF vectorizer and return
    the model's class-probability matrix (one row per input text)."""
    dense_features = vectorizer.transform(X_batch_text).toarray()
    return model.predict(dense_features)

# SHAP Partition explainer over \W+ token boundaries, applied to the samples.
partition_explainer = feature_attributions_explainer.partitionexplainer(
    make_predictions, r"\W+", selected_categories)(X_batch_text)
Text Plot
# Interactive SHAP text plot: per-token attributions for each output class.
partition_explainer.visualize()
[0]
outputs
alt.atheism
comp.graphics
rec.motorcycles
sci.space
talk.politics.misc
inputs
From:
prb@
access.
digex.
net (
Pat)
Subject:
Re:
Near
Miss
Asteroids (
Q)
Organization:
Express
Access
Online
Communications,
Greenbelt,
MD
USA
Lines:
4
Distribution:
sci
NNTP-
Posting-
Host:
access.
digex.
net
TRry
the
SKywatch
project
in
Arizona.
pat
[1]
outputs
alt.atheism
comp.graphics
rec.motorcycles
sci.space
talk.politics.misc
inputs
From:
cobb@
alexia.
lis.
uiuc.edu (
Mike Cobb)
Subject: Science
and
theories
Organization:
University of
Illinois at
Urbana
Lines:
19
As
per various
threads on science
and
creationism, I'
ve started
dabbling into a
book called Christianity and
the Nature of Science
by JP Moreland. A
question
that I had come from one of
his comments. He stated
that God
is not
necessarily a religious term,
but
could be used as
other scientific terms that
give explanation for events
or
theories, without being a
proven scientific
fact. I
think I got his
point -- I can quote
the section if I'
m being vague.
The
examples he
gave were quarks
and continental plates. Are there
explanations of science or
parts of theories that
are not measurable in
and of
themselves, or can everything be quantified,
measured, tested, etc.?
MAC
--
****************************************************************
Michael
A.
Cobb
"...and I won'
t raise taxes on
the middle University of
Illinois
class to pay
for my programs." Champaign-
Urbana
-Bill
Clinton 3rd
Debate cobb@
alexia.lis.uiuc.
edu
Nobody can explain
everything to anybody. G.K.Chesterton
Bar Plots
Bar Plot 1
# Re-export the shap module and the SHAP values from the explainer object.
shap = partition_explainer.shap
shap_values = partition_explainer.shap_values

# Batch-mean token attribution toward sample 0's predicted class,
# shown largest-first (descending order).
mean_attr_0 = shap_values[:, :, selected_categories[preds[0]]].mean(axis=0)
shap.plots.bar(mean_attr_0, max_display=15, order=shap.Explanation.argsort.flip)
Bar Plot 2
# Per-token attributions for sample 0 toward its predicted class, descending.
sample0_attr = shap_values[0, :, selected_categories[preds[0]]]
shap.plots.bar(sample0_attr, max_display=15, order=shap.Explanation.argsort.flip)
Bar Plot 3
# Batch-mean token attributions toward sample 1's predicted class, descending.
mean_attr_1 = shap_values[:, :, selected_categories[preds[1]]].mean(axis=0)
shap.plots.bar(mean_attr_1, max_display=15, order=shap.Explanation.argsort.flip)
Bar Plot 4
# Per-token attributions for sample 1 toward its predicted class, descending.
sample1_attr = shap_values[1, :, selected_categories[preds[1]]]
shap.plots.bar(sample1_attr, max_display=15, order=shap.Explanation.argsort.flip)
Waterfall Plots
Waterfall Plot 1
# Waterfall: how sample 0's token attributions accumulate from the base value
# to the predicted-class output.
sample0_exp = shap_values[0][:, selected_categories[preds[0]]]
shap.waterfall_plot(sample0_exp, max_display=15)
Waterfall Plot 2
# Waterfall: same view for sample 1 toward its predicted class.
sample1_exp = shap_values[1][:, selected_categories[preds[1]]]
shap.waterfall_plot(sample1_exp, max_display=15)
Force Plot
import re

# Tokenize the first sample the same way as before; use a RAW string so
# "\W" is not parsed as a string escape (non-raw "\W+" triggers a
# DeprecationWarning/SyntaxWarning on modern Python).
tokens = re.split(r"\W+", X_batch_text[0].lower())

shap.initjs()
# tokens[:-1] drops the trailing empty token that re.split produces when the
# text ends on a non-word character, keeping feature_names aligned with the
# SHAP value vector.
shap.force_plot(shap_values.base_values[0][preds[0]], shap_values[0][:, preds[0]].values,
                feature_names=tokens[:-1], out_names=selected_categories[preds[0]])
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.